{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "start_time": "2023-08-22T14:22:18.843573Z",
     "end_time": "2023-08-22T14:22:19.472395Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import GPT2Tokenizer\n",
    "\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
    "\n",
    "tokenizer.pad_token = \"[PAD]\"\n",
    "\n",
    "def preprocess(text):\n",
    "    encoding = tokenizer.encode_plus(text, truncation=True, padding=True, return_tensors='pt')\n",
    "    input_ids = encoding['input_ids']\n",
    "    attention_mask = encoding['attention_mask']\n",
    "    return input_ids, attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "from transformers import GPT2LMHeadModel, GPT2Config\n",
    "\n",
    "config = GPT2Config.from_pretrained('gpt2')\n",
    "model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-08-22T14:22:19.474395Z",
     "end_time": "2023-08-22T14:22:21.229663Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def train(model, optimizer, epochs, training_loader):\n",
    "    model.train()\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        loss = 0\n",
    "        for i, batch in enumerate(training_loader):\n",
    "            inputs, targets = batch\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs, labels=targets)\n",
    "            loss = outputs.loss\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            print('Epoch {:2d} | Batch {:2d}/{:2d} | Loss {:.4f}'.format(epoch+1, i+1, len(training_loader), loss))\n",
    "\n",
    "        torch.save(model.state_dict(), f'checkpoint_epoch_{epoch+1}.pt')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-08-22T14:22:21.231664Z",
     "end_time": "2023-08-22T14:22:21.247328Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-08-22T14:22:21.247328Z",
     "end_time": "2023-08-22T14:22:21.280337Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "import pandas as pd\n",
    "\n",
    "class NewsDataset(Dataset):\n",
    "    def __init__(self, data):\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        input_ids, attention_mask = preprocess(self.data.iloc[index]['headline_text'])\n",
    "        input_ids = torch.tensor(input_ids).squeeze()\n",
    "        attention_mask = torch.tensor(attention_mask).squeeze()\n",
    "        return input_ids, attention_mask\n",
    "\n",
    "dataset = pd.read_csv('common_datasets/abcnews-date-text.csv')[['headline_text']]\n",
    "\n",
    "training_set = NewsDataset(dataset[:50000])\n",
    "training_loader = DataLoader(training_set, batch_size=8)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(device)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-08-22T14:22:21.278331Z",
     "end_time": "2023-08-22T14:22:21.338402Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\anaconda3\\lib\\site-packages\\transformers\\optimization.py:407: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "C:\\Users\\gqhua\\AppData\\Local\\Temp\\ipykernel_12424\\2116918423.py:13: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  input_ids = torch.tensor(input_ids).squeeze()\n",
      "C:\\Users\\gqhua\\AppData\\Local\\Temp\\ipykernel_12424\\2116918423.py:14: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  attention_mask = torch.tensor(attention_mask).squeeze()\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "stack expects each tensor to be equal size, but got [6] at entry 0 and [8] at entry 1",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mRuntimeError\u001B[0m                              Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[19], line 4\u001B[0m\n\u001B[0;32m      1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtransformers\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m AdamW\n\u001B[0;32m      2\u001B[0m optimizer \u001B[38;5;241m=\u001B[39m AdamW(model\u001B[38;5;241m.\u001B[39mparameters(), lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-5\u001B[39m)\n\u001B[1;32m----> 4\u001B[0m \u001B[43mtrain\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43moptimizer\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m3\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtraining_loader\u001B[49m\u001B[43m)\u001B[49m\n",
      "Cell \u001B[1;32mIn[17], line 8\u001B[0m, in \u001B[0;36mtrain\u001B[1;34m(model, optimizer, epochs, training_loader)\u001B[0m\n\u001B[0;32m      6\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m epoch \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(epochs):\n\u001B[0;32m      7\u001B[0m     loss \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[1;32m----> 8\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m i, batch \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(training_loader):\n\u001B[0;32m      9\u001B[0m         inputs, targets \u001B[38;5;241m=\u001B[39m batch\n\u001B[0;32m     10\u001B[0m         inputs, targets \u001B[38;5;241m=\u001B[39m inputs\u001B[38;5;241m.\u001B[39mto(device), targets\u001B[38;5;241m.\u001B[39mto(device)\n",
      "File \u001B[1;32mD:\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:634\u001B[0m, in \u001B[0;36m_BaseDataLoaderIter.__next__\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m    631\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_sampler_iter \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m    632\u001B[0m     \u001B[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001B[39;00m\n\u001B[0;32m    633\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_reset()  \u001B[38;5;66;03m# type: ignore[call-arg]\u001B[39;00m\n\u001B[1;32m--> 634\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_next_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m    635\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_num_yielded \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[0;32m    636\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_dataset_kind \u001B[38;5;241m==\u001B[39m _DatasetKind\u001B[38;5;241m.\u001B[39mIterable \u001B[38;5;129;01mand\u001B[39;00m \\\n\u001B[0;32m    637\u001B[0m         \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_IterableDataset_len_called \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m \\\n\u001B[0;32m    638\u001B[0m         \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_num_yielded \u001B[38;5;241m>\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_IterableDataset_len_called:\n",
      "File \u001B[1;32mD:\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:678\u001B[0m, in \u001B[0;36m_SingleProcessDataLoaderIter._next_data\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m    676\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_next_data\u001B[39m(\u001B[38;5;28mself\u001B[39m):\n\u001B[0;32m    677\u001B[0m     index \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_next_index()  \u001B[38;5;66;03m# may raise StopIteration\u001B[39;00m\n\u001B[1;32m--> 678\u001B[0m     data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_dataset_fetcher\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfetch\u001B[49m\u001B[43m(\u001B[49m\u001B[43mindex\u001B[49m\u001B[43m)\u001B[49m  \u001B[38;5;66;03m# may raise StopIteration\u001B[39;00m\n\u001B[0;32m    679\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_pin_memory:\n\u001B[0;32m    680\u001B[0m         data \u001B[38;5;241m=\u001B[39m _utils\u001B[38;5;241m.\u001B[39mpin_memory\u001B[38;5;241m.\u001B[39mpin_memory(data, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_pin_memory_device)\n",
      "File \u001B[1;32mD:\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:54\u001B[0m, in \u001B[0;36m_MapDatasetFetcher.fetch\u001B[1;34m(self, possibly_batched_index)\u001B[0m\n\u001B[0;32m     52\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m     53\u001B[0m     data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdataset[possibly_batched_index]\n\u001B[1;32m---> 54\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcollate_fn\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdata\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32mD:\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\collate.py:264\u001B[0m, in \u001B[0;36mdefault_collate\u001B[1;34m(batch)\u001B[0m\n\u001B[0;32m    203\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mdefault_collate\u001B[39m(batch):\n\u001B[0;32m    204\u001B[0m \u001B[38;5;250m    \u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m    205\u001B[0m \u001B[38;5;124;03m        Function that takes in a batch of data and puts the elements within the batch\u001B[39;00m\n\u001B[0;32m    206\u001B[0m \u001B[38;5;124;03m        into a tensor with an additional outer dimension - batch size. The exact output type can be\u001B[39;00m\n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m    262\u001B[0m \u001B[38;5;124;03m            >>> default_collate(batch)  # Handle `CustomType` automatically\u001B[39;00m\n\u001B[0;32m    263\u001B[0m \u001B[38;5;124;03m    \"\"\"\u001B[39;00m\n\u001B[1;32m--> 264\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mcollate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbatch\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcollate_fn_map\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mdefault_collate_fn_map\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[1;32mD:\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\collate.py:142\u001B[0m, in \u001B[0;36mcollate\u001B[1;34m(batch, collate_fn_map)\u001B[0m\n\u001B[0;32m    139\u001B[0m transposed \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mlist\u001B[39m(\u001B[38;5;28mzip\u001B[39m(\u001B[38;5;241m*\u001B[39mbatch))  \u001B[38;5;66;03m# It may be accessed twice, so we use a list.\u001B[39;00m\n\u001B[0;32m    141\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(elem, \u001B[38;5;28mtuple\u001B[39m):\n\u001B[1;32m--> 142\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m [collate(samples, collate_fn_map\u001B[38;5;241m=\u001B[39mcollate_fn_map) \u001B[38;5;28;01mfor\u001B[39;00m samples \u001B[38;5;129;01min\u001B[39;00m transposed]  \u001B[38;5;66;03m# Backwards compatibility.\u001B[39;00m\n\u001B[0;32m    143\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m    144\u001B[0m     \u001B[38;5;28;01mtry\u001B[39;00m:\n",
      "File \u001B[1;32mD:\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\collate.py:142\u001B[0m, in \u001B[0;36m<listcomp>\u001B[1;34m(.0)\u001B[0m\n\u001B[0;32m    139\u001B[0m transposed \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mlist\u001B[39m(\u001B[38;5;28mzip\u001B[39m(\u001B[38;5;241m*\u001B[39mbatch))  \u001B[38;5;66;03m# It may be accessed twice, so we use a list.\u001B[39;00m\n\u001B[0;32m    141\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(elem, \u001B[38;5;28mtuple\u001B[39m):\n\u001B[1;32m--> 142\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m [\u001B[43mcollate\u001B[49m\u001B[43m(\u001B[49m\u001B[43msamples\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcollate_fn_map\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcollate_fn_map\u001B[49m\u001B[43m)\u001B[49m \u001B[38;5;28;01mfor\u001B[39;00m samples \u001B[38;5;129;01min\u001B[39;00m transposed]  \u001B[38;5;66;03m# Backwards compatibility.\u001B[39;00m\n\u001B[0;32m    143\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m    144\u001B[0m     \u001B[38;5;28;01mtry\u001B[39;00m:\n",
      "File \u001B[1;32mD:\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\collate.py:119\u001B[0m, in \u001B[0;36mcollate\u001B[1;34m(batch, collate_fn_map)\u001B[0m\n\u001B[0;32m    117\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m collate_fn_map \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m    118\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m elem_type \u001B[38;5;129;01min\u001B[39;00m collate_fn_map:\n\u001B[1;32m--> 119\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mcollate_fn_map\u001B[49m\u001B[43m[\u001B[49m\u001B[43melem_type\u001B[49m\u001B[43m]\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbatch\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcollate_fn_map\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcollate_fn_map\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m    121\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m collate_type \u001B[38;5;129;01min\u001B[39;00m collate_fn_map:\n\u001B[0;32m    122\u001B[0m         \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(elem, collate_type):\n",
      "File \u001B[1;32mD:\\anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\collate.py:162\u001B[0m, in \u001B[0;36mcollate_tensor_fn\u001B[1;34m(batch, collate_fn_map)\u001B[0m\n\u001B[0;32m    160\u001B[0m     storage \u001B[38;5;241m=\u001B[39m elem\u001B[38;5;241m.\u001B[39m_typed_storage()\u001B[38;5;241m.\u001B[39m_new_shared(numel, device\u001B[38;5;241m=\u001B[39melem\u001B[38;5;241m.\u001B[39mdevice)\n\u001B[0;32m    161\u001B[0m     out \u001B[38;5;241m=\u001B[39m elem\u001B[38;5;241m.\u001B[39mnew(storage)\u001B[38;5;241m.\u001B[39mresize_(\u001B[38;5;28mlen\u001B[39m(batch), \u001B[38;5;241m*\u001B[39m\u001B[38;5;28mlist\u001B[39m(elem\u001B[38;5;241m.\u001B[39msize()))\n\u001B[1;32m--> 162\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstack\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbatch\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mout\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mout\u001B[49m\u001B[43m)\u001B[49m\n",
      "\u001B[1;31mRuntimeError\u001B[0m: stack expects each tensor to be equal size, but got [6] at entry 0 and [8] at entry 1"
     ]
    }
   ],
   "source": [
    "from transformers import AdamW\n",
    "optimizer = AdamW(model.parameters(), lr=1e-5)\n",
    "\n",
    "train(model, optimizer, 3, training_loader)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "generated = model.generate(\n",
    "    input_ids = torch.tensor(tokenizer.encode(\"Australian politicians\", add_special_tokens=True)).unsqueeze(0),\n",
    "    max_length = 50,\n",
    "    temperature = 1.0,\n",
    "    top_k = 0,\n",
    "    top_p = 0.9,\n",
    "    do_sample=True,\n",
    "    num_return_sequences=1\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "generated_title = tokenizer.decode(generated[0], skip_special_tokens=True)\n",
    "print(generated_title)"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
